
[Attention] Optimize FlashInfer MetadataBuilder Build call #21137


Open · LucasWilkinson wants to merge 4 commits into main from lwilkinson/flash-infer-host-buffers

Conversation

LucasWilkinson (Collaborator) commented Jul 17, 2025

Essential Elements of an Effective PR Description Checklist

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

Purpose

FlashInfer prefers host-side CPU buffers in many cases; for example: https://github.com/flashinfer-ai/flashinfer/blob/3c40456effae8b9c5b1a11c0d1e0594295b1a312/flashinfer/prefill.py#L1430-L1436

So we pass host-side buffers (which we have access to since #20466) to reduce D2H transfers.
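
To make the idea concrete, here is a minimal sketch (not the actual vLLM code; the tensor names, values, and page-size math are illustrative assumptions) of deriving FlashInfer's planning inputs from host-side buffers instead of pulling device tensors back with `.cpu()`:

```python
import torch

# Hypothetical host-side views taken from CommonAttentionMetadata
# (values and page size are illustrative only).
query_start_loc_cpu = torch.tensor([0, 3, 7], dtype=torch.int32)  # cumulative query lengths
seq_lens_cpu = torch.tensor([11, 20], dtype=torch.int32)          # per-request KV lengths
page_size = 16

# Before: planning inputs were derived from device tensors, so every .cpu()
# call was a blocking device-to-host copy:
#   qo_indptr = query_start_loc_gpu.cpu()
#   seq_lens  = seq_lens_gpu.cpu()

# After: compute the planning inputs directly on the host; FlashInfer's plan()
# performs the host-to-device copies itself, and those can be non-blocking.
qo_indptr = query_start_loc_cpu
paged_kv_last_page_len = (seq_lens_cpu - 1) % page_size + 1
num_pages_per_req = (seq_lens_cpu + page_size - 1) // page_size
```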

Trace from main showing D2H transfers in plan

[profiler trace screenshot]

Test Plan

Test Result

Accuracy Results

VLLM_ATTENTION_BACKEND=FLASHINFER lm_eval --model vllm --model_args pretrained=meta-llama/Meta-Llama-3-8B-Instruct --tasks gsm8k --batch_size auto
...
INFO 07-17 20:33:43 [cuda.py:253] Using FlashInfer backend on V1 engine.
...
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.7536|±  |0.0119|
|     |       |strict-match    |     5|exact_match|↑  |0.7551|±  |0.0118|

Benchmark Results

Benchmark Command:

python benchmarks/benchmark_throughput.py --model meta-llama/Llama-3.2-3B-Instruct --dataset-name random --input-len 256 --output-len 128 --num-prompts <N> --seed 42

Results (3 runs per condition, mean ± standard error):

|num-prompts|Main Branch (req/s)|This PR (req/s)|
|----------:|------------------:|--------------:|
|          1|        1.58 ± 0.06|    1.90 ± 0.03|
|          8|       13.06 ± 0.11|   14.32 ± 0.21|
|         16|       26.00 ± 0.07|   28.74 ± 0.13|
|         32|       47.84 ± 0.57|   46.53 ± 1.57|
|         64|       76.14 ± 0.45|   81.43 ± 3.43|
|        128|      116.99 ± 6.10|  127.78 ± 7.50|
|        256|      164.45 ± 6.12|  177.70 ± 3.88|

Tested on an NVIDIA B200 GPU with meta-llama/Llama-3.2-3B-Instruct (256→128 tokens).

(Optional) Documentation Update


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of tests to quickly catch errors. You can run additional CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

mergify bot added the rocm (Related to AMD ROCm) and speculative-decoding labels on Jul 17, 2025

mergify bot commented Jul 17, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @LucasWilkinson.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

gemini-code-assist bot (Contributor) left a comment


Code Review

This pull request is a significant and well-executed refactoring of the attention backend infrastructure. The primary goal of decoupling the metadata builders from the model runner has been achieved, which improves modularity and maintainability. The optimization for FlashInfer by preparing metadata on the CPU is a key improvement and has been implemented correctly.

The introduction of CommonAttentionMetadata as a unified data structure is a solid design choice that simplifies the data flow to the attention backends. The refactoring of the speculative decoding logic, particularly in vllm/v1/spec_decode/eagle.py, to remove the Triton kernel in favor of a more readable PyTorch/NumPy implementation is a notable improvement.
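
For orientation, a rough sketch of what such a unified container could look like; the field names here are assumptions pieced together from this PR discussion (query_start_loc, seq_lens, block table, plus CPU views), not the actual vLLM definition:

```python
from dataclasses import dataclass
import torch

@dataclass
class CommonAttentionMetadataSketch:
    # Device-side tensors consumed by the attention kernels.
    query_start_loc: torch.Tensor
    seq_lens: torch.Tensor
    block_table_tensor: torch.Tensor
    # Host-side views of the runner's persistent pinned buffers, so that
    # metadata builders can plan on the CPU without D2H copies.
    query_start_loc_cpu: torch.Tensor
    seq_lens_cpu: torch.Tensor
    # Batch shape information.
    num_reqs: int
    num_actual_tokens: int
```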

The addition of a comprehensive test suite in tests/v1/attention/test_attention_backends.py is excellent. It provides strong validation for the correctness of this large-scale refactoring by comparing various backends against a reference implementation under realistic conditions.

Overall, the changes are of high quality and represent a positive step forward for the codebase. I have not identified any issues of high or critical severity.

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>

Optimize V1 FlashInfer backend to use CPU host buffers

- Replace GPU-to-CPU transfers with direct CPU tensor construction
- Build planning tensors from existing CommonAttentionMetadata CPU buffers
- Reduce from 6x to 1x .cpu() calls during FlashInfer planning
- Fix test mocks to handle correct argument count
- Maintain compatibility with GPUModelRunner and FlashInfer V1 backend

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>

don't transfer block table

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>

optimize

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
@LucasWilkinson force-pushed the lwilkinson/flash-infer-host-buffers branch from 87ccacf to 8af5f3b on July 18, 2025 00:36
mergify bot removed the needs-rebase label on Jul 18, 2025
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
WoosukKwon (Collaborator) left a comment


BTW why don't we use Numpy instead of PyTorch CPU tensors? Except for some edge cases, Numpy is usually faster in my experience.

fhl2000 (Contributor) commented Jul 18, 2025

Could we still pass the device tensors to FlashInfer's plan() rather than host tensors? We might want to support full cudagraph for FlashInfer in the future (currently implemented in rough form in #20059), which requires managing device-side persistent buffers that can be reused across different decode wrappers. Here, one decode wrapper corresponds to a runtime shape that needs to be captured.

Also, if we pass the host tensors to the wrapper, it seems that H2D transfers still exist. If I remember correctly, Sglang's implementation overrides the plan functions that still pass host-side persistent buffers, and also explicitly avoids certain D2H transfers.

Hope it's helpful! @LucasWilkinson

LucasWilkinson (Collaborator, Author) replied:

> BTW why don't we use Numpy instead of PyTorch CPU tensors? Except for some edge cases, Numpy is usually faster in my experience.

I've found that going to and from NumPy (i.e. .numpy(), torch.from_numpy) can be a bit slow and is only worth it if you are going to do a lot of ops; since FlashInfer ultimately wants torch tensors, and for most of these there are only one or two ops per tensor, I'm not sure it's worth going to NumPy. But I can scrub for tensors that are manipulated a lot.
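
For illustration only (a made-up op on made-up sizes), the round trip being described looks like this:

```python
import torch

seq_lens_cpu = torch.randint(1, 2048, (256,), dtype=torch.int32)  # example host tensor

# NumPy route: convert, operate, convert back (FlashInfer ultimately wants torch tensors).
arr = seq_lens_cpu.numpy()              # zero-copy view, but an extra Python-level hop
pages_np = (arr + 15) // 16
pages = torch.from_numpy(pages_np)

# Torch route: a single op on the CPU tensor, no conversions in either direction.
pages_torch = (seq_lens_cpu + 15) // 16
```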

LucasWilkinson (Collaborator, Author) commented Jul 18, 2025

> Could we still pass the device tensors to FlashInfer's plan() rather than host tensors? We might want to support full cudagraph for FlashInfer in the future (currently implemented in rough form in #20059), which requires managing device-side persistent buffers that can be reused across different decode wrappers. Here, one decode wrapper corresponds to a runtime shape that needs to be captured.

If you look in FlashInfer's BatchDecodeWithPagedKVCacheWrapper you'll see the buffers get copied in the cudagraph path regardless: https://github.com/flashinfer-ai/flashinfer/blob/1e9a41ad7f0efc5989bb0a2bf7e954902c8c73af/flashinfer/decode.py#L892-L910 and will get copied to the host: https://github.com/flashinfer-ai/flashinfer/blob/1e9a41ad7f0efc5989bb0a2bf7e954902c8c73af/flashinfer/decode.py#L925-L926

> Also, if we pass the host tensors to the wrapper, it seems that H2D transfers still exist.

Yes; however, H2D transfers are preferred over D2H since they can be done in a non-blocking fashion and do not force synchronization with the GPU. For the build call we are trying to optimize the CPU overhead, so the fire-and-forget nature of the H2D transfers is better than depending on a D2H transfer.
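
A minimal sketch of the difference (assumes a CUDA device; buffer names are illustrative):

```python
import torch

host_buf = torch.zeros(1024, dtype=torch.int32, pin_memory=True)  # pinned host buffer
dev_buf = torch.zeros(1024, dtype=torch.int32, device="cuda")

# H2D from pinned memory: fire-and-forget; the CPU can keep building metadata
# while the copy happens on the stream.
dev_buf.copy_(host_buf, non_blocking=True)

# D2H via .cpu(): the CPU has to wait for the GPU stream to produce the data
# before it can read it, which is exactly the stall the build call avoids.
host_copy = dev_buf.cpu()
```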

> If I remember correctly, Sglang's implementation overrides the plan functions that still pass host-side persistent buffers, and also explicitly avoids certain D2H transfers.

That's effectively what this PR does; the CPU buffers in CommonAttentionMetadata are views into the gpu_model_runner's persistent input_batch host-side tensors.
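
Roughly the following pattern (a simplified sketch; the actual buffers live in the runner's input batch and the names here are approximations):

```python
import torch

# Persistent pinned host-side buffer owned by the model runner (size illustrative).
MAX_NUM_REQS = 1024
query_start_loc_host = torch.zeros(MAX_NUM_REQS + 1, dtype=torch.int32, pin_memory=True)

# Each step, the metadata exposes a view of the live prefix: no allocation, no copy.
num_reqs = 8
query_start_loc_cpu = query_start_loc_host[: num_reqs + 1]
assert query_start_loc_cpu.data_ptr() == query_start_loc_host.data_ptr()
```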

fhl2000 (Contributor) commented Jul 18, 2025

> If I remember correctly, Sglang's implementation overrides the plan functions that still pass host-side persistent buffers,

Oh, my bad! Sorry, I meant that they are passing the device-side buffers.

> If you look in FlashInfer's BatchDecodeWithPagedKVCacheWrapper you'll see the buffers get copied in the cudagraph path regardless: https://github.com/flashinfer-ai/flashinfer/blob/1e9a41ad7f0efc5989bb0a2bf7e954902c8c73af/flashinfer/decode.py#L892-L910 and will get copied to the host: https://github.com/flashinfer-ai/flashinfer/blob/1e9a41ad7f0efc5989bb0a2bf7e954902c8c73af/flashinfer/decode.py#L925-L926

I am wondering if we could override this plan function so that the wrapper directly owns the device-side persistent buffers from vLLM, avoiding any unnecessary copies (device-to-device or host-to-device). At least for qo_indptr, which is equivalent to query_start_loc, we already have both the CPU and GPU versions of it from common_attn_metadata, so we can just reuse them without any further copy.

LucasWilkinson (Collaborator, Author) commented Jul 18, 2025

> > If I remember correctly, Sglang's implementation overrides the plan functions that still pass host-side persistent buffers,

> Oh, my bad! Sorry, I meant that they are passing the device-side buffers.

> > If you look in FlashInfer's BatchDecodeWithPagedKVCacheWrapper you'll see the buffers get copied in the cudagraph path regardless: https://github.com/flashinfer-ai/flashinfer/blob/1e9a41ad7f0efc5989bb0a2bf7e954902c8c73af/flashinfer/decode.py#L892-L910 and will get copied to the host: https://github.com/flashinfer-ai/flashinfer/blob/1e9a41ad7f0efc5989bb0a2bf7e954902c8c73af/flashinfer/decode.py#L925-L926

> I am wondering if we could override this plan function so that the wrapper directly owns the device-side persistent buffers from vLLM, avoiding any unnecessary copies (device-to-device or host-to-device). At least for qo_indptr, which is equivalent to query_start_loc, we already have both the CPU and GPU versions of it from common_attn_metadata, so we can just reuse them without any further copy.

Is this what you are referring to? https://github.com/sgl-project/sglang/blob/719b29f218a09642193c4bda2a7ffa32829d5604/python/sglang/srt/layers/attention/flashinfer_backend.py#L1229 (I'm not that familiar with sglang.) This is an interesting idea; thanks for sharing! Regardless, even in this overridden version they pass host-side buffers (https://github.com/sgl-project/sglang/blob/719b29f218a09642193c4bda2a7ffa32829d5604/python/sglang/srt/layers/attention/flashinfer_backend.py#L1334-L1336), so if we want to override plan in the future I think we would still want this PR as a stepping stone (and override plan in a follow-up PR).

mgoin (Member) commented Jul 18, 2025

Could you make sure to test the trtllm case in the FlashInfer backend as well? I just want to make sure this choice is also preferable for that backend, if it's affected.
